{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "801db364",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import coremltools as ct\n",
    "import clip\n",
    "import numpy as np\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26f7dcff",
   "metadata": {},
   "source": [
    "# 1. Export TextEncoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f89976b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import CLIPTextModelWithProjection, CLIPTokenizerFast\n",
    "\n",
    "# Hugging Face checkpoint used for both the tokenizer and the text encoder.\n",
    "model_id = \"openai/clip-vit-base-patch32\"\n",
    "\n",
    "tokenizer = CLIPTokenizerFast.from_pretrained(model_id)\n",
    "\n",
    "# The model is loaded with return_dict=False and switched to eval mode\n",
    "# before it is handed to the tracer.\n",
    "model = CLIPTextModelWithProjection.from_pretrained(model_id, return_dict=False)\n",
    "model.eval()\n",
    "\n",
    "# Only the token ids of a sample prompt are used as the tracing example.\n",
    "encoded_prompt = tokenizer(\"a photo of a cat\", return_tensors=\"pt\")\n",
    "example_input = encoded_prompt[\"input_ids\"]\n",
    "\n",
    "traced_model = torch.jit.trace(model, example_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c87abd71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed token-sequence length of the converted model's input.\n",
    "# If max_seq_length is 77 as in the original model, the validation fails\n",
    "# (see details at the end of the notebook); 76 works fine with the app.\n",
    "max_seq_length = 76\n",
    "\n",
    "# Input: a single prompt of int32 token ids, shape (1, max_seq_length).\n",
    "text_input_spec = ct.TensorType(\n",
    "    name=\"prompt\",\n",
    "    shape=[1, max_seq_length],\n",
    "    dtype=np.int32,\n",
    ")\n",
    "\n",
    "# Names and dtypes assigned to the two outputs of the converted model.\n",
    "text_output_specs = [\n",
    "    ct.TensorType(name=\"embOutput\", dtype=np.float32),\n",
    "    ct.TensorType(name=\"embOutput2\", dtype=np.float32),\n",
    "]\n",
    "\n",
    "text_encoder_model = ct.convert(\n",
    "    traced_model,\n",
    "    convert_to=\"mlprogram\",\n",
    "    minimum_deployment_target=ct.target.iOS16,\n",
    "    inputs=[text_input_spec],\n",
    "    outputs=text_output_specs,\n",
    ")\n",
    "text_encoder_model.save(\"TextEncoder_float32_test.mlpackage\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "617e4e6b",
   "metadata": {},
   "source": [
    "## Validate export precision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd6af02a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the exported package from disk so the saved artifact is what gets validated.\n",
    "model = ct.models.MLModel(\"TextEncoder_float32_test.mlpackage\")\n",
    "\n",
    "# Tokenize with the clip tokenizer and keep only the first max_seq_length\n",
    "# token ids so the tensor matches the converted model's input shape.\n",
    "text = clip.tokenize(\"a photo of a cat\")[:, :max_seq_length]\n",
    "\n",
    "# Alternative: use CLIPTokenizerFast instead.\n",
    "# text = tokenizer(\"a photo of a cat\", return_tensors=\"pt\", padding=\"max_length\", max_length=max_seq_length)\n",
    "# text = text.data['input_ids'].to(torch.int32)\n",
    "\n",
    "# Feed the same token ids to the traced PyTorch model and the Core ML model.\n",
    "out = traced_model(text)\n",
    "predictions = model.predict({\"prompt\": text})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c29d0a98",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the first 10 values of row 0 from the PyTorch and Core ML outputs side by side.\n",
    "torch_head = out[0][0, :10]\n",
    "coreml_head = predictions[\"embOutput\"][0, :10]\n",
    "\n",
    "print(\"PyTorch TextEncoder ckpt out for \\\"a photo of a cat\\\":\\n>>>\", torch_head)\n",
    "print(\"\\nCoreML TextEncoder ckpt out for \\\"a photo of a cat\\\":\\n>>>\", coreml_head)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c0d9c70",
   "metadata": {},
   "source": [
    "You can see that there is some loss in precision, but it is still acceptable."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca182b4a",
   "metadata": {},
   "source": [
    "# 2. Export ImageEncoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68521589",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import CLIPVisionModelWithProjection, CLIPProcessor\n",
    "\n",
    "# Same checkpoint as the text encoder; this time the vision tower is loaded.\n",
    "model_id = \"openai/clip-vit-base-patch32\"\n",
    "processor = CLIPProcessor.from_pretrained(model_id)\n",
    "model = CLIPVisionModelWithProjection.from_pretrained(model_id, return_dict=False)\n",
    "model.eval()\n",
    "\n",
    "# Preprocess a sample image with the processor; its pixel_values tensor\n",
    "# is the example input for tracing (and its shape is reused at conversion).\n",
    "img = Image.open(\"love-letters-and-hearts.jpg\")\n",
    "processed_image = processor(images=img, return_tensors=\"pt\")\n",
    "example_input = processed_image[\"pixel_values\"]\n",
    "\n",
    "traced_model = torch.jit.trace(model, example_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "304ae7b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "image_mean = processor.image_processor.image_mean\n",
    "image_std = processor.image_processor.image_std\n",
    "\n",
    "# Fold the processor's normalization into the Core ML image input.\n",
    "# bias is per channel (-mean / std); scale is one scalar built from the\n",
    "# first channel's std only.\n",
    "# NOTE(review): scale and bias therefore use different std values for\n",
    "# channels 1 and 2 - confirm this approximation is intended.\n",
    "bias = [-image_mean[c] / image_std[c] for c in range(3)]\n",
    "scale = 1.0 / (image_std[0] * 255.0)\n",
    "\n",
    "image_input_scale = ct.ImageType(\n",
    "    name=\"colorImage\",\n",
    "    color_layout=ct.colorlayout.RGB,\n",
    "    shape=example_input.shape,\n",
    "    scale=scale,\n",
    "    bias=bias,\n",
    "    channel_first=True,\n",
    ")\n",
    "\n",
    "# Names and dtypes assigned to the two outputs of the converted model.\n",
    "image_output_specs = [\n",
    "    ct.TensorType(name=\"embOutput\", dtype=np.float32),\n",
    "    ct.TensorType(name=\"embOutput2\", dtype=np.float32),\n",
    "]\n",
    "\n",
    "image_encoder_model = ct.convert(\n",
    "    traced_model,\n",
    "    convert_to=\"mlprogram\",\n",
    "    minimum_deployment_target=ct.target.iOS16,\n",
    "    inputs=[image_input_scale],\n",
    "    outputs=image_output_specs,\n",
    ")\n",
    "\n",
    "image_encoder_model.save(\"ImageEncoder_float32.mlpackage\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3c5008e",
   "metadata": {},
   "source": [
    "## Validate export"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "759bb57d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.transforms as transforms\n",
    "\n",
    "# Reload the exported package from disk so the saved artifact is what gets validated.\n",
    "image_encoder = ct.models.MLModel('ImageEncoder_float32.mlpackage')\n",
    "\n",
    "# Force 3-channel RGB: a grayscale, palette, RGBA or CMYK file would otherwise\n",
    "# not yield the (224, 224, 3) array assumed below, nor match the RGB\n",
    "# colorImage input of the converted model.\n",
    "imgPIL = Image.open(\"love-letters-and-hearts.jpg\").convert(\"RGB\")\n",
    "imgPIL = imgPIL.resize((224, 224), Image.BICUBIC)\n",
    "\n",
    "# Build the PyTorch-side input by hand: HWC -> NCHW float32, scaled to [0, 1].\n",
    "img_np = np.asarray(imgPIL).astype(np.float32) # (224, 224, 3)\n",
    "img_np = img_np[np.newaxis, :, :, :] # (1, 224, 224, 3)\n",
    "img_np = np.transpose(img_np, [0, 3, 1, 2]) # (1, 3, 224, 224)\n",
    "img_np = img_np / 255.0\n",
    "torch_tensor_input = torch.from_numpy(img_np)\n",
    "\n",
    "# Per-channel normalization with the processor's mean/std, applied only on the\n",
    "# PyTorch side; the Core ML model receives the resized PIL image directly.\n",
    "transform_model = torch.nn.Sequential(\n",
    "        transforms.Normalize(mean=processor.image_processor.image_mean,\n",
    "                             std=processor.image_processor.image_std),\n",
    ")\n",
    "\n",
    "predictions = image_encoder.predict({'colorImage': imgPIL})\n",
    "\n",
    "# Inference only: no autograd graph is needed for the reference forward pass.\n",
    "with torch.no_grad():\n",
    "    out = traced_model(transform_model(torch_tensor_input))\n",
    "\n",
    "print(\"PyTorch ImageEncoder ckpt out for jpg:\\n>>>\", out[0][0, :10])\n",
    "print(\"\\nCoreML ImageEncoder ckpt out for jpg:\\n>>>\", predictions['embOutput'][0, :10])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aac310d4",
   "metadata": {},
   "source": [
    "## Test result for max_seq_length = 77"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f24ec713",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import CLIPTextModelWithProjection, CLIPTokenizerFast\n",
    "\n",
    "# --- Load and trace the text encoder (same steps as section 1) ---\n",
    "model_id = \"openai/clip-vit-base-patch32\"\n",
    "tokenizer = CLIPTokenizerFast.from_pretrained(model_id)\n",
    "model = CLIPTextModelWithProjection.from_pretrained(model_id, return_dict=False)\n",
    "model.eval()\n",
    "\n",
    "encoded_prompt = tokenizer(\"a photo of a cat\", return_tensors=\"pt\")\n",
    "example_input = encoded_prompt[\"input_ids\"]\n",
    "traced_model = torch.jit.trace(model, example_input)\n",
    "\n",
    "# --- Convert with the original model's sequence length ---\n",
    "# If max_seq_length is 77 as in the original model, the validation fails\n",
    "# (see the printed outputs below); 76 works fine with the app.\n",
    "max_seq_length = 77\n",
    "\n",
    "text_input_spec = ct.TensorType(\n",
    "    name=\"prompt\",\n",
    "    shape=[1, max_seq_length],\n",
    "    dtype=np.int32,\n",
    ")\n",
    "text_output_specs = [\n",
    "    ct.TensorType(name=\"embOutput\", dtype=np.float32),\n",
    "    ct.TensorType(name=\"embOutput2\", dtype=np.float32),\n",
    "]\n",
    "text_encoder_model = ct.convert(\n",
    "    traced_model,\n",
    "    convert_to=\"mlprogram\",\n",
    "    minimum_deployment_target=ct.target.iOS16,\n",
    "    inputs=[text_input_spec],\n",
    "    outputs=text_output_specs,\n",
    ")\n",
    "\n",
    "# --- Validate: same token ids through Core ML and the traced PyTorch model ---\n",
    "# Tokenize with the clip tokenizer and keep the first max_seq_length ids.\n",
    "text = clip.tokenize(\"a photo of a cat\")[:, :max_seq_length]\n",
    "\n",
    "out = traced_model(text)\n",
    "predictions = text_encoder_model.predict({\"prompt\": text})\n",
    "\n",
    "torch_head = out[0][0, :10]\n",
    "coreml_head = predictions[\"embOutput\"][0, :10]\n",
    "print(\"PyTorch TextEncoder ckpt out for \\\"a photo of a cat\\\":\\n>>>\", torch_head)\n",
    "print(\"\\nCoreML TextEncoder ckpt out for \\\"a photo of a cat\\\":\\n>>>\", coreml_head)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
